// 到这个网址可以在线编码，实操tensorflow.js：https://chn.ai/tf.html


// Width of each token's embedding vector.
const d = 768;

// Toy vocabulary: each word maps to an integer token id.
const tokens = {'每天': 0, '吃': 1, '苹果': 2, '可以': 3, '补充': 4, '维生素': 5, '和': 6, '纤维': 7, '对': 8, '身体': 9, '有': 10, '好处': 11};

// The embedding table needs one row for every id up to the largest one used,
// so derive its size from the vocabulary instead of hard-coding it.
const vocabSize = Math.max(...Object.values(tokens)) + 1;

// Token ids as a 1-D int32 tensor: embedding lookups take integer indices.
const a = tf.tensor1d(Object.values(tokens), 'int32');

// Lookup table of shape [vocabSize, d]; applying it to `a` yields one d-wide
// row per token, i.e. x has shape [numTokens, d].
const embedding = tf.layers.embedding({inputDim: vocabSize, outputDim: d});
const x = embedding.apply(a);
console.log('x', x.shape);

// Output widths of the query, key and value projections
// (queries and keys must share a width so q·k is defined).
const [dQ, dK, dV] = [24, 24, 30];

// Projection matrices stored as [outDim, d], filled by tf.randomUniform with its default range.
const wQuery = tf.randomUniform([dQ, d]);
const wKey = tf.randomUniform([dK, d]);
const wValue = tf.randomUniform([dV, d]);

// Push every embedding row through `w`: (w · xᵀ)ᵀ has shape [numTokens, outDim].
const project = (w) => w.matMul(x, false, true).transpose();

const querys = project(wQuery);
const keys = project(wKey);
const values = project(wValue);

for (const [label, t] of Object.entries({querys, keys, values})) {
  console.log(label, t.shape);
}

// Raw attention scores: entry (i, j) is the dot product of query i with key j,
// giving a [numTokens, numTokens] matrix.
const os = querys.matMul(keys, false, true);

// Scale by sqrt(dK) to keep the dot products from growing with the key width.
// Math.sqrt gives a plain number, so no extra tensor is allocated just for the divisor.
// Softmax over the last axis then turns each row into weights that sum to 1.
const weights = os.div(Math.sqrt(dK)).softmax(-1);
console.log('weights', weights.shape);

// Attention output: each row is a weighted sum of value vectors, shape [numTokens, dV].
const atts = weights.matMul(values);
console.log('atts', atts.shape);